package two;

import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code k} from standard input and prints the value of
 * {@link #getNumLong(int, int)} for them.
 *
 * <p>For {@code n >= 2} the value is {@code (k-1)^n + (-1)^n * (k-1)}, which is the number of ways
 * to colour {@code n} positions arranged in a cycle with {@code k} colours so that neighbours
 * differ; for {@code n == 1} it is {@code k}. (Presumably this is the "colour the sectors of a
 * ring" exercise -- confirm against the original problem statement.)
 */
public class test10304 {

    /**
     * Entry point: reads {@code n} then {@code k} (whitespace-separated) from stdin and prints the
     * result of {@link #getNumLong(int, int)}.
     *
     * <p>No special cases are needed for {@code k == 1} or {@code k == 2}: {@code getNumLong}
     * answers them in constant time with the same 0/2 values.
     */
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt();
            int k = scanner.nextInt();
            System.out.println(getNumLong(n, k));
        }
    }

    /**
     * Returns {@link #getNumLong(int, int)} narrowed to {@code int}.
     *
     * @param n number of positions; must be at least 1
     * @param k number of colours; must be non-negative
     * @return the count as an {@code int}
     * @throws IllegalArgumentException if {@code n < 1} or {@code k < 0}
     * @throws ArithmeticException if the result does not fit in an {@code int} (rather than
     *     silently wrapping around)
     */
    public static int getNum(int n, int k) {
        return Math.toIntExact(getNumLong(n, k));
    }

    /**
     * Computes a(n) where a(1) = k, a(2) = k(k-1), a(3) = k(k-1)(k-2) and
     * a(n) = (k-2) * a(n-1) + (k-1) * a(n-2) for n >= 4; for n >= 2 this equals
     * (k-1)^n + (-1)^n * (k-1).
     *
     * <p>Runs in constant time for {@code k <= 2}. For {@code k >= 3} the value grows like
     * (k-1)^n, so the loop runs at most 61 iterations before returning or overflowing.
     *
     * @param n number of positions; must be at least 1
     * @param k number of colours; must be non-negative
     * @return the count
     * @throws IllegalArgumentException if {@code n < 1} or {@code k < 0}
     * @throws ArithmeticException if the result does not fit in a {@code long}
     */
    public static long getNumLong(int n, int k) {
        if (n < 1) {
            // n <= 0 previously recursed forever; reject it explicitly.
            throw new IllegalArgumentException("n must be at least 1, got " + n);
        }
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative, got " + k);
        }
        if (n == 1) {
            return k;
        }
        long kMinus1 = k - 1L;
        long kMinus2 = k - 2L;
        long twoBack = Math.multiplyExact((long) k, kMinus1); // a(2)
        long oneBack = Math.multiplyExact(twoBack, kMinus2);  // a(3)
        if (n == 2) {
            return twoBack;
        }
        if (n == 3) {
            return oneBack;
        }
        if (k <= 2) {
            // |k-1| <= 1 makes (k-1)^n + (-1)^n (k-1) periodic with period 2 from n = 2 on,
            // so large n (e.g. 10^9) is answered without iterating.
            return n % 2 == 0 ? twoBack : oneBack;
        }
        // Count down instead of up so the loop counter cannot overflow for n near Integer.MAX_VALUE.
        for (int remaining = n - 3; remaining > 0; remaining--) {
            // Both products are positive and smaller than the sum for k >= 3, so an
            // ArithmeticException here means the true result really exceeds a long.
            long next = Math.addExact(
                    Math.multiplyExact(kMinus2, oneBack),
                    Math.multiplyExact(kMinus1, twoBack));
            twoBack = oneBack;
            oneBack = next;
        }
        return oneBack;
    }
}
